{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sys\n",
    "import tensorflow as tf\n",
    "import copy\n",
    "from sklearn import metrics\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "# depending on the classification model use, we might need to import other packages\n",
    "#from sklearn import svm\n",
    "#from sklearn.ensemble import RandomForestClassifier\n",
    "import sklearn\n",
    "from matplotlib import pyplot as plt\n",
    "import pickle as pkl\n",
    "\n",
    "from datasets import DatasetUCI\n",
    "from envs import LalEnvTargetAccuracy\n",
    "\n",
    "from estimator import Estimator\n",
    "from helpers import Minibatch, ReplayBuffer\n",
    "from dqn import DQN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from Test_AL import check_performance, check_performance_for_figure"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Strategies to test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Random sampling\n",
    "rs = True\n",
    "# Uncertainty sampling\n",
    "us = True\n",
    "# LAL-RL learnt strategy on other datasets\n",
    "rl = True\n",
    "# LAL-RL learnt strategy on the same dataset (another half)\n",
    "rl_notransfer = False\n",
    "# LAL-independent and LAL-iterative strategies\n",
    "lal = False\n",
    "# ALBE strategy that learns a combination of rs, us and quire\n",
    "albe = False\n",
    "# QUIRE strategy\n",
    "quire = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "DIRNAME_TRANSFER = './agents/1-australian-logreg-8-to-1/'\n",
    "DIRNAME_NOTRANSFER = ''\n",
    "DIRNAME_RESULTS = './AL_results/test-agent-australian.p'\n",
    "\n",
    "TOLERANCE_LEVEL = 0.98\n",
    "test_dataset_names = ['australian']\n",
    "\n",
    "N_STATE_ESTIMATION = 30\n",
    "SUBSET = -1 # -1 for using all datapoints, 0 for even, 1 for odd\n",
    "SIZE = 100\n",
    "\n",
    "N_JOBS = 1 # can set more if we want to parallelise\n",
    "QUALITY_METHOD = metrics.accuracy_score\n",
    "\n",
    "N_EXPERIMENTS = 500"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Can use different models for classifier\n",
    "<br>\n",
    "`LogisticRegression(n_jobs=N_JOBS)` <br>\n",
    "SVM: <br>\n",
    "`svm.SVC(probability=True)` <br>\n",
    "RF: <br>\n",
    "`RandomForestClassifier(50, oob_score=True, n_jobs=N_JOBS)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "dataset = DatasetUCI(possible_names=test_dataset_names, n_state_estimation=N_STATE_ESTIMATION, subset=SUBSET, size=SIZE)\n",
    "model = LogisticRegression(n_jobs=N_JOBS)\n",
    "env = LalEnvTargetAccuracy(dataset, model, quality_method=QUALITY_METHOD, tolerance_level=TOLERANCE_LEVEL)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare AL methods"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Methods for random sampling and uncertainty sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "if rs:\n",
    "    from Test_AL import policy_random\n",
    "if us:\n",
    "    from Test_AL import policy_uncertainty"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load RL model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "if rl:\n",
    "    from Test_AL import policy_rl\n",
    "    tf.reset_default_graph()\n",
    "    # Load the DQN agent from DIRNAME_TRANSFER\n",
    "    agent = DQN(experiment_dir=DIRNAME_TRANSFER,\n",
    "            observation_length=N_STATE_ESTIMATION,\n",
    "            learning_rate=1e-3,\n",
    "            batch_size=32,\n",
    "            target_copy_factor=0.01,\n",
    "            bias_average=0,\n",
    "           )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load RL model with no transfer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if rl_notransfer:\n",
    "    from Test_AL import policy_rl\n",
    "    tf.reset_default_graph()\n",
    "    # Load the DQN agent from DIRNAME_NOTRANSFER\n",
    "    agent = DQN(experiment_dir=DIRNAME_NOTRANSFER,\n",
    "            observation_length=N_STATE_ESTIMATION,\n",
    "            learning_rate=1e-3,\n",
    "            batch_size=32,\n",
    "            target_copy_factor=0.01,\n",
    "            bias_average=0,\n",
    "           )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load LAL models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if lal:\n",
    "    from sklearn.ensemble import RandomForestRegressor\n",
    "    from Test_AL import policy_LAL\n",
    "    \n",
    "    # LAL-INDEPENDENT\n",
    "    # parameters\n",
    "    fn = 'LAL-randomtree-simulatedunbalanced-big.npz'\n",
    "    parameters = {'est': 2000, 'depth': 40, 'feat': 6 }\n",
    "    # load data\n",
    "    filename = '../LALfiles/'+fn\n",
    "    regression_data = np.load(filename)\n",
    "    regression_features = regression_data['arr_0']\n",
    "    regression_labels = regression_data['arr_1']\n",
    "    # build model\n",
    "    print('building lal model..')\n",
    "    lal_model1 = RandomForestRegressor(n_estimators = parameters['est'], max_depth = parameters['depth'], \n",
    "                                     max_features=parameters['feat'], oob_score=True, n_jobs=8)\n",
    "    lal_model1.fit(regression_features, np.ravel(regression_labels))    \n",
    "    print('the model is built!')\n",
    "    print('oob score = ', lal_model1.oob_score_)\n",
    "\n",
    "    # LAL-ITERATIVE\n",
    "    # parameters\n",
    "    fn = 'LAL-iterativetree-simulatedunbalanced-big.npz'\n",
    "    parameters = {'est': 1000, 'depth': 40, 'feat': 6 }\n",
    "    # load data\n",
    "    filename = '../LALfiles/'+fn\n",
    "    regression_data = np.load(filename)\n",
    "    regression_features = regression_data['arr_0']\n",
    "    regression_labels = regression_data['arr_1']\n",
    "    # build model\n",
    "    print('building lal model..')\n",
    "    lal_model2 = RandomForestRegressor(n_estimators = parameters['est'], max_depth = parameters['depth'], \n",
    "                                     max_features=parameters['feat'], oob_score=True, n_jobs=8)\n",
    "    lal_model2.fit(regression_features, np.ravel(regression_labels))    \n",
    "    print('the model is built!')\n",
    "    print('oob score = ', lal_model2.oob_score_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prepare for ALBE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Depending on the classifier, need to use different adapter <br>\n",
    "`SklearnProbaAdapter(LogisticRegression(n_jobs=N_JOBS))` for logistic regression or <br>\n",
    "`SklearnProbaAdapter(svm.SVC(probability=True))` for SVM or <br>\n",
    "`SklearnProbaAdapter(RandomForestClassifier(50, n_jobs=1))` for RF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if albe:\n",
    "    import sys\n",
    "    sys.path.append('./libact/')\n",
    "    from libact.base.dataset import Dataset\n",
    "    from libact.query_strategies import QUIRE, UncertaintySampling, ActiveLearningByLearning\n",
    "    from libact.models import SklearnProbaAdapter\n",
    "    from Test_AL import policy_ALBE\n",
    "    \n",
    "    def reset_albe(dataset, env):\n",
    "        \"\"\"Initialises libact to perform ALBE\"\"\"\n",
    "        adapter = SklearnProbaAdapter(LogisticRegression(n_jobs=1)) \n",
    "        nolabels = np.array(([None] * len(dataset.train_labels)))\n",
    "        libactlabels = nolabels\n",
    "        libactlabels[env.indeces_known] = dataset.train_labels[env.indeces_known]\n",
    "        trn_ds = Dataset(dataset.train_data, libactlabels)\n",
    "        # max number of iterations is needed here\n",
    "        qs = ActiveLearningByLearning(trn_ds, query_strategies=[UncertaintySampling(trn_ds, model=adapter), QUIRE(trn_ds)], T=1000, uniform_sampler=True, model=adapter)\n",
    "        #qs = QUIRE(trn_ds, model=adapter)\n",
    "        #qs = ActiveLearningByLearning(trn_ds, query_strategies=[QUIRE(trn_ds)], T=100, uniform_sampler=True, model=adapter)\n",
    "        return qs, trn_ds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prepare for QUIRE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "if quire:\n",
    "    import sys\n",
    "    sys.path.append('./libact/')\n",
    "    from libact.base.dataset import Dataset\n",
    "    from libact.query_strategies import QUIRE\n",
    "    from libact.models import SklearnProbaAdapter\n",
    "    from Test_AL import policy_QUIRE\n",
    "    \n",
    "    def reset_quire(dataset, env):\n",
    "        \"\"\"\"Initialises libact to perform QUIRE\"\"\"\n",
    "        adapter = SklearnProbaAdapter(LogisticRegression(n_jobs=1))\n",
    "        nolabels = np.array(([None] * len(dataset.train_labels)))\n",
    "        libactlabels = nolabels\n",
    "        libactlabels[env.indeces_known] = dataset.train_labels[env.indeces_known]\n",
    "        trn_ds = Dataset(dataset.train_data, libactlabels)\n",
    "        # max number of iterations is needed here\n",
    "        qs = QUIRE(trn_ds, model=adapter)\n",
    "        return qs, trn_ds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the experiemnts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Results will be stored in all_results dictionary\n",
    "all_results = {}\n",
    "all_scores_rand = []\n",
    "all_scores_uncert = []\n",
    "all_scores_rl = []\n",
    "all_scores_rl_notransfer = []\n",
    "all_scores_LAL_independant = []\n",
    "all_scores_LAL_iterative = []\n",
    "all_scores_ALBE = []\n",
    "all_scores_QUIRE = []\n",
    "\n",
    "for experiment in range(N_EXPERIMENTS):\n",
    "    print(experiment, end=' ')\n",
    "    # reset the environment\n",
    "    state, next_action_state = env.reset()\n",
    "    # run the experiments\n",
    "    # 1. copy the initial state and environment \n",
    "    # so that all strategies start from the same point            \n",
    "    # 2. done variable indicates when terminal state is reached\n",
    "    # 3. repeat until terminal state is reached\n",
    "    # 4. select an action according to the policy\n",
    "    # to see the prob of selected action: taken_action_state = next_action_state_uncert[:,action]\n",
    "    # 5. make a step in the environment\n",
    "    # 6. keep track of the scores in the episode\n",
    "    if rs:\n",
    "        env_rand = copy.deepcopy(env)\n",
    "        state_rand = copy.deepcopy(state)\n",
    "        done = False\n",
    "        while not(done):\n",
    "            action = policy_random(env_rand.n_actions)\n",
    "            _, _, _, done = env_rand.step(action)\n",
    "        all_scores_rand.append(env_rand.episode_qualities)\n",
    "    if us:\n",
    "        next_action_state_uncert = next_action_state\n",
    "        env_uncert = copy.deepcopy(env)\n",
    "        state_uncert = copy.deepcopy(state)\n",
    "        done = False\n",
    "        while not(done):\n",
    "            action = policy_uncertainty(next_action_state_uncert[0,:])\n",
    "            next_state, next_action_state_uncert, reward, done = env_uncert.step(action)\n",
    "        all_scores_uncert.append(env_uncert.episode_qualities)\n",
    "    if rl:\n",
    "        next_action_state_rl = next_action_state\n",
    "        env_rl = copy.deepcopy(env)\n",
    "        state_rl = copy.deepcopy(state)\n",
    "        done = False\n",
    "        while not(done):\n",
    "            action = policy_rl(agent, state_rl, next_action_state_rl)        \n",
    "            next_state, next_action_state_rl, reward, done = env_rl.step(action)\n",
    "            state_rl = next_state\n",
    "        all_scores_rl.append(env_rl.episode_qualities)\n",
    "    if rl_notransfer:\n",
    "        next_action_state_rl_notransfer = next_action_state\n",
    "        env_rl_notransfer = copy.deepcopy(env)\n",
    "        state_rl_notransfer = copy.deepcopy(state)\n",
    "        done = False\n",
    "        while not(done):\n",
    "            action = policy_rl(agent_notransfer, state_rl_notransfer, next_action_state_rl_notransfer)        \n",
    "            next_state, next_action_state_rl_notransfer, reward, done = env_rl_notransfer.step(action)\n",
    "            state_rl_notransfer = next_state\n",
    "        all_scores_rl_notransfer.append(env_rl_notransfer.episode_qualities)\n",
    "    if lal:\n",
    "        next_action_state_LAL_independant = next_action_state\n",
    "        env_LAL_independant = copy.deepcopy(env)\n",
    "        state_LAL_independant = copy.deepcopy(state)\n",
    "        done = False\n",
    "        while not(done):\n",
    "            env_LAL_independant.for_lal()\n",
    "            action = policy_LAL(dataset, env_LAL_independant, lal_model1)\n",
    "            next_state, next_action_state_LAL_independant, reward, done = env_LAL_independant.step(action)\n",
    "        all_scores_LAL_independant.append(env_LAL_independant.episode_qualities)\n",
    "        next_action_state_LAL_iterative = next_action_state\n",
    "        env_LAL_iterative = copy.deepcopy(env)\n",
    "        state_LAL_iterative = copy.deepcopy(state)\n",
    "        done = False\n",
    "        while not(done):\n",
    "            env_LAL_iterative.for_lal()\n",
    "            action = policy_LAL(dataset, env_LAL_iterative, lal_model2)\n",
    "            next_state, next_action_state_LAL_iterative, reward, done = env_LAL_iterative.step(action)\n",
    "        all_scores_LAL_iterative.append(env_LAL_iterative.episode_qualities)\n",
    "    if albe:\n",
    "        next_action_state_ALBE = next_action_state\n",
    "        env_ALBE = copy.deepcopy(env)\n",
    "        state_ALBE = copy.deepcopy(state)\n",
    "        qs, trn_ds = reset_albe(dataset, env_ALBE)\n",
    "        done = False\n",
    "        while not(done):\n",
    "            action = policy_ALBE(qs, trn_ds, env_ALBE, dataset)\n",
    "            next_state, next_action_state_ALBE, reward, done = env_ALBE.step(action)\n",
    "        all_scores_ALBE.append(env_ALBE.episode_qualities)\n",
    "    if quire:\n",
    "        next_action_state_QUIRE = next_action_state\n",
    "        env_QUIRE = copy.deepcopy(env)\n",
    "        state_QUIRE = copy.deepcopy(state)\n",
    "        qs, trn_ds = reset_quire(dataset, env_QUIRE)\n",
    "        done = False\n",
    "        while not(done):\n",
    "            action = policy_QUIRE(qs, trn_ds, env_QUIRE, dataset)\n",
    "            next_state, next_action_state_QUIRE, reward, done = env_QUIRE.step(action)\n",
    "        all_scores_QUIRE.append(env_QUIRE.episode_qualities)\n",
    "# record the results\n",
    "all_results['all_scores_rand'] = all_scores_rand\n",
    "all_results['all_scores_uncert'] = all_scores_uncert\n",
    "all_results['all_scores_rl'] = all_scores_rl\n",
    "all_results['all_scores_rl_notransfer'] = all_scores_rl_notransfer\n",
    "all_results['all_scores_LAL_independant'] = all_scores_LAL_independant\n",
    "all_results['all_scores_LAL_iterative'] = all_scores_LAL_iterative\n",
    "all_results['all_scores_ALBE'] = all_scores_ALBE\n",
    "all_results['all_scores_QUIRE'] = all_scores_QUIRE\n",
    "pkl.dump(all_results, open(DIRNAME_RESULTS, \"wb\" ))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "all_results = pkl.load(open(DIRNAME_RESULTS, \"rb\" ) )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "all_scores_rand = all_results['all_scores_rand']\n",
    "all_scores_uncert = all_results['all_scores_uncert']\n",
    "all_scores_rl = all_results['all_scores_rl']\n",
    "all_scores_rl_notransfer = all_results['all_scores_rl_notransfer']\n",
    "all_scores_LAL_independant = all_results['all_scores_LAL_independant']\n",
    "all_scores_LAL_iterative = all_results['all_scores_LAL_iterative']\n",
    "all_scores_ALBE = all_results['all_scores_ALBE']\n",
    "all_scores_QUIRE = all_results['all_scores_QUIRE']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check the results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compute the mean duration, it's std, median and max"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "max_duration = 0\n",
    "if rs:\n",
    "    print(\"Random\")\n",
    "    all_scores_rand, all_durations_rand = check_performance(all_scores_rand)\n",
    "    max_duration = max(max_duration, max(all_durations_rand))\n",
    "if us:\n",
    "    print(\"Uncertainty\")\n",
    "    all_scores_uncert, all_durations_uncert = check_performance(all_scores_uncert)\n",
    "    max_duration = max(max_duration, max(all_durations_uncert))\n",
    "if rl:\n",
    "    print(\"RL\")\n",
    "    all_scores_rl, all_durations_rl = check_performance(all_scores_rl)\n",
    "    max_duration = max(max_duration, max(all_durations_rl))\n",
    "if rl_notransfer:\n",
    "    print(\"RL without transfer\")\n",
    "    all_scores_rl_notransfer, all_durations_rl_notransfer = check_performance(all_scores_rl_notransfer)\n",
    "    max_duration = max(max_duration, max(all_durations_rl_notransfer))\n",
    "if lal:\n",
    "    print(\"LAL independant\")\n",
    "    all_scores_LAL_independant, all_durations_LAL_independant = check_performance(all_scores_LAL_independant)\n",
    "    max_duration = max(max_duration, max(all_durations_LAL_independant))\n",
    "    print(\"LAL iterative\")\n",
    "    all_scores_LAL_iterative, all_durations_LAL_iterative = check_performance(all_scores_LAL_iterative)\n",
    "    max_duration = max(max_duration, max(all_durations_LAL_iterative))\n",
    "if albe:\n",
    "    print(\"ALBE\")\n",
    "    all_scores_ALBE, all_durations_ALBE = check_performance(all_scores_ALBE)\n",
    "    max_duration = max(max_duration, max(all_durations_ALBE))\n",
    "if quire:\n",
    "    print(\"QUIRE\")\n",
    "    all_scores_QUIRE, all_durations_QUIRE = check_performance(all_scores_QUIRE)\n",
    "    max_duration = max(max_duration, max(all_durations_QUIRE))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compute the relative scores that can be used to plot the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "scores_relative_rand = check_performance_for_figure(all_scores_rand, max_duration)\n",
    "scores_relative_uncert = check_performance_for_figure(all_scores_uncert, max_duration)\n",
    "scores_relative_rl = check_performance_for_figure(all_scores_rl, max_duration)\n",
    "scores_relative_rl_notransfer = check_performance_for_figure(all_scores_rl_notransfer, max_duration)\n",
    "scores_relative_LAL_independant = check_performance_for_figure(all_scores_LAL_independant, max_duration)\n",
    "scores_relative_LAL_iterative = check_performance_for_figure(all_scores_LAL_iterative, max_duration)\n",
    "scores_relative_ALBE = check_performance_for_figure(all_scores_ALBE, max_duration)\n",
    "scores_relative_QUIRE = check_performance_for_figure(all_scores_QUIRE, max_duration)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Plot the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20,10))\n",
    "if rs:\n",
    "    m_line = np.mean(scores_relative_rand, axis=0)\n",
    "    var_line = np.var(scores_relative_rand, axis=0)\n",
    "    plt.plot(m_line, linewidth=2.0, label = 'random', color='k')\n",
    "    plt.fill_between(range(np.size(m_line)), m_line - var_line, m_line + var_line, color='k', alpha=0.2)\n",
    "if us:\n",
    "    m_line = np.mean(scores_relative_uncert, axis=0)\n",
    "    var_line = np.var(scores_relative_uncert, axis=0)\n",
    "    plt.plot(m_line, linewidth=2.0, label = 'uncertainty', color='b')\n",
    "    plt.fill_between(range(np.size(m_line)), m_line - var_line, m_line + var_line, color='b', alpha=0.2)\n",
    "if rl:\n",
    "    m_line = np.mean(scores_relative_rl, axis=0)\n",
    "    var_line = np.var(scores_relative_rl, axis=0)\n",
    "    plt.plot(m_line, linewidth=2.0, label = 'rl', color='red')\n",
    "    plt.fill_between(range(np.size(m_line)), m_line - var_line, m_line + var_line, color='red', alpha=0.2)\n",
    "if rl_notransfer:\n",
    "    m_line = np.mean(scores_relative_rl_notransfer, axis=0)\n",
    "    var_line = np.var(scores_relative_rl_notransfer, axis=0)\n",
    "    plt.plot(m_line, linewidth=2.0, label = 'rl no transfer', color='red')\n",
    "    plt.fill_between(range(np.size(m_line)), m_line - var_line, m_line + var_line, color='red', alpha=0.2)\n",
    "if lal:\n",
    "    m_line = np.mean(scores_relative_LAL_independant, axis=0)\n",
    "    var_line = np.var(scores_relative_LAL_independant, axis=0)\n",
    "    plt.plot(m_line, linewidth=2.0, label = 'LAL-independant', color='c')\n",
    "    plt.fill_between(range(np.size(m_line)), m_line - var_line, m_line + var_line, color='c', alpha=0.2)    \n",
    "    m_line = np.mean(scores_relative_LAL_iterative, axis=0)\n",
    "    var_line = np.var(scores_relative_LAL_iterative, axis=0)\n",
    "    plt.plot(m_line, linewidth=2.0, label = 'LAL-iterative', color='m')\n",
    "    plt.fill_between(range(np.size(m_line)), m_line - var_line, m_line + var_line, color='m', alpha=0.2)\n",
    "if albe:\n",
    "    m_line = np.mean(scores_relative_ALBE, axis=0)\n",
    "    var_line = np.var(scores_relative_ALBE, axis=0)\n",
    "    plt.plot(m_line, linewidth=2.0, label = 'ALBE', color='g')\n",
    "    plt.fill_between(range(np.size(m_line)), m_line - var_line, m_line + var_line, color='g', alpha=0.2)\n",
    "if quire:\n",
    "    m_line = np.mean(scores_relative_QUIRE, axis=0)\n",
    "    var_line = np.var(scores_relative_QUIRE, axis=0)\n",
    "    plt.plot(m_line, linewidth=2.0, label = 'QUIRE', color='y')\n",
    "    plt.fill_between(range(np.size(m_line)), m_line - var_line, m_line + var_line, color='y', alpha=0.2)\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
